Qual o tamanho da sua incerteza?

Pacote recalibratiNN

Carolina Musso

LatinR 2024

Introdução: Redes Neurais hoje

  • Deveriam ser capaz de quantificar sua incerteza.
  • As RNs podem ser construídas para produzir resultados probabilísticos:
    • Otimizadas pela log-verossimilhança.
    • Como qualquer modelo, pode ser mal calibrado.
      • Um IC de 95% deve conter 95% da saída real.
      • \(\mathbb{P}(Y \leq \hat{F_Y}^{-1}(p))= p , \forall ~ p \in [0,1]\)

Note

Se otimizado pelo MSE, estarei assumindo uma distribuição normal.

Observando a falta de calibração

Considere um conjunto de dados sintéticos \((x_i, y_i), i \in (1, ..., n)\) gerados por um modelo não linear heteroscedástico:

\[ x_i \sim Uniform(1,10)\\ \]

\[ y_i|x_i \sim Normal(\mu = f_1(x_i), \sigma= f_2(x_i)) \\ f_1(x) = 5x^2 + 10 ~; ~ f_2(x) = 30x \] E então ajustamos um modelo de regressão linear simples…

\[ \hat{y}_i = \beta_0 + \beta_1 x_i +\epsilon_i, ~\epsilon_i ~ iid \sim N(0,\sigma) \]

Observando a falta de calibração

Uma regressão linear simples, só para aquecer.

  • Global Coverage: 94.45%.

Valores PIT

  • Histograma dos valores da Transformação Integral de Probabilidade (PIT).

  • Seja \(F_Y(y)\) a função de distribuição acumulada de uma variável aleatória contínua Y, então:

\[U = F_Y (Y ) ∼ Uniform(0, 1)\]

  • Em particular, se \(Y \sim Normal(\mu, \sigma)\):

\[Y = F_Y^{-1} (U) ∼ Normal(\mu, \sigma)\]

Visualizando os valores PIT

Recalibração

Pacotes disponíveis

  • R: probably

  • Python: ml_insights

  • Apenas calibração global, focada em problemas de classificação e aplicável apenas no espaço de covariáveis.

Método:

  • Torres et al (2024): Calibração local para regressões e em diversas representações do espaço de covariáveis: útil para Redes Neurais Artificiais (RNAs).

Algoritmo

O pacote está disponível

  • No GitHub

  • No CRAN

  • por agora, apenas para modelos de regressão gaussiana.

The recalibratiNN package

  • Sete funções e 10 dependências
Function Description Arguments
PIT_global Calculates PIT values for the entire dataset ycal, yhat, mse
PIT_local Calculates PIT values for each cluster xcal, ycal, yhat, mse, clusters, p_neighbours, PIT
gg_PIT_global Plots PIT values histogram pit, type, fill, alpha, print_p
gg_PIT_local Plots PIT values densities for kmeans clusters pit_local, alpha, linewidth, pal, facet
recalibrate Recalibrates the model yhat_new, space_new, space_cal, pit_values, mse, type, p_neighbours, epsilon

Visualizando a falta de calibração

  • With a Neural Network example
  • Com um exemplo de Rede Neural

Data

set.seed(42)   # The Answer to the Ultimate Question of Life, The Universe, and Everything

n <- 10000

x <- cbind(x1 = runif(n, -3, 3),
           x2 = runif(n, -5, 5))

mu_fun <- function(x) {
  abs(x[,1]^3 - 50*sin(x[,2]) + 30)}

mu <- mu_fun(x)
y <- rnorm(n, 
           mean = mu, 
           sd=20*(abs(x[,2]/(x[,1]+ 10))))

split1 <- 0.6
split2 <- 0.8

x_train <- x[1:(split1*n),]
y_train <- y[1:(split1*n)]

x_cal  <- x[(split1*n+1):(n*split2),]
y_cal  <- y[(split1*n+1):(n*split2)]

x_test <- x[(split2*n+1):n,]
y_test  <- y[(split2*n+1):n]

Keras

model_nn <- keras_model_sequential()

model_nn |> 
  layer_dense(input_shape=2,
              units=800,
              use_bias=T,
              activation = "relu",
              kernel_initializer="random_normal",
              bias_initializer = "zeros") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units=800,
              use_bias=T,
              activation = "relu",
              kernel_initializer="random_normal",
              bias_initializer = "zeros") |> 
  layer_dropout(rate = 0.1) |> 
  layer_dense(units=800,
              use_bias=T,
              activation = "relu",
              kernel_initializer="random_normal",
              bias_initializer = "zeros") |> 
   layer_batch_normalization() |> 
  layer_dense(units = 1,
              activation = "linear",
              kernel_initializer = "zeros",
              bias_initializer = "zeros")

model_nn |> 
  compile(optimizer=optimizer_adam( ),
    loss = "mse")

model_nn |> 
  fit(x = x_train, 
      y = y_train,
      validation_data = list(x_cal, y_cal),
      callbacks = callback_early_stopping(
        monitor = "val_loss",
        patience = 20,
        restore_best_weights = T),
      batch_size = 128,
      epochs = 1000)


y_hat_cal <- predict(model_nn, x_cal)
y_hat_test <- predict(model_nn, x_test)

Descalibração global

pit <- PIT_global(ycal = y_cal, # outcomes of calibration set
                  yhat = y_hat_cal, # predictions of calibration
                  mse = MSE_cal) # MSE of calibration set 
                       
gg_PIT_global(pit,
              type = "histogram") # other customizations are available

Descalibração local

pit_local <- PIT_local(xcal = x_cal, # covariates of calibration set
                       ycal = y_cal, # outcomes of calibration set
                       yhat = y_hat_cal, # predictions of calibration set
                       mse = MSE_cal) # MSE of calibration set

gg_PIT_local(pit_local) # there are customizations available

Cobertura

Recalibração

recalibrated <- 
  recalibrate(
    pit_values = pit,      # global pit values calculated earlier.
    mse = MSE_cal,         # MSE from calibration set
    yhat_new = y_hat_test, # predictions of test set
    space_cal = x_cal,     # covariates or any representation of calibration set
    space_new = x_test,    # covariates or any representation of test set
    type = "local",        # type of calibration
    p_neighbours = 0.08)   # proportion of calibration to use as nearest neighbors

y_hat_rec <- recalibrated$y_samples_calibrated_wt
  • Pronto!
  • Esses novos valores em y_hat_rec são mais calibrados que os originais.

Bora ver?

Cobertura

Dados reais

Diamantes

Depois da Recalibração

Calibrado usando uma segunda camada escondida.

Conclusões e Trabalhos Futuros

  • Visualização eficaz da falta de calibração.
  • Vantagens em relação a outros pacotes disponíveis
    • Focado em modelos de regressão
    • Recalibração local
    • Recalibração em camadas intermediárias.

Desenvolvimentos Futuros:

  • Integração com outros pacotes (tidymodels), tipos de entrada mais amplos, métodos de validação cruzada

  • Lidar com modelos com distribuições preditivas arbitrárias.

Obrigada!

GitHub